

import numpy as np
import numba
from numba import jit, njit


try:
    # `profile` is presumably injected as a builtin by a line profiler
    # (e.g. line_profiler's `kernprof -l`) when one is running.
    profile
except NameError:
    # No profiler active: provide a no-op stand-in so `@profile` decorations
    # elsewhere still work. Only NameError is caught, so unrelated errors are
    # not silently swallowed the way a bare `except:` would.
    def profile(x):
        """Return *x* unchanged; identity stand-in for a profiler decorator."""
        return x
# from numba import jitclass          # import the decorator
# from numba import int8, int64, float32, optional    # import the types

@jit(nopython=True, nogil=True)
def set_mean_variance_all(path_scores, rec_rate):
    """Precompute residual means and variances for every locus and start haplotype.

    For each locus ``locus`` and haplotype ``h`` in {0, 1}, the slices
    ``residual_means[locus, :, h]`` and ``residual_variances[locus, :, h]`` are
    filled by ``set_mean_variance_index``. That is the accumulated per-trait
    mean and variance contributed by all loci *after* ``locus``, given that the
    path is on haplotype ``h`` at ``locus``.

    Args:
        path_scores: 3-D array indexed ``[locus, trait, haplotype]``; only
            haplotype columns 0 and 1 are read.
        rec_rate: per-locus probability of switching haplotype.

    Returns:
        ``(residual_means, residual_variances)``, both float32 arrays of shape
        ``(nLoci, nTraits, 2)``.
    """
    n_loci, n_traits, _ = path_scores.shape
    out_shape = (n_loci, n_traits, 2)

    residual_means = np.zeros(out_shape, dtype=np.float32)
    residual_variances = np.zeros(out_shape, dtype=np.float32)

    for locus in range(n_loci):
        for start_hap in range(2):
            # The 1-D slices are views, so set_mean_variance_index writes
            # straight into the output arrays.
            mean_view = residual_means[locus, :, start_hap]
            var_view = residual_variances[locus, :, start_hap]
            set_mean_variance_index(locus, start_hap, mean_view, var_view, path_scores, rec_rate)

    return residual_means, residual_variances

@jit(nopython=True, nogil=True)
def set_mean_variance_index(index, hap, residual_mean, residual_variance, path_scores, rec_rate):
    """Fill per-trait residual mean/variance for loci after ``index``.

    Starting deterministically on haplotype ``hap`` at locus ``index``, the
    haplotype probabilities are propagated forward one locus at a time. A
    switch happens with probability ``rec_rate`` per step. At each later locus
    the expected score and its variance (under the two-point distribution over
    haplotypes) are accumulated per trait. Per-locus variances are simply
    summed, so covariances between loci are not included.

    Args:
        index: locus the path is conditioned on.
        hap: haplotype (0 or 1) at ``index``; any non-zero value is treated as 1.
        residual_mean: 1-D output array of length nTraits; overwritten.
        residual_variance: 1-D output array of length nTraits; overwritten.
        path_scores: 3-D array indexed ``[locus, trait, haplotype]``.
        rec_rate: per-locus haplotype switching probability.
    """
    nLoci, nTraits, _ = path_scores.shape

    residual_mean[:] = 0
    residual_variance[:] = 0

    # Float literals so numba infers a single float type for p_0/p_1 instead
    # of unifying an int start with later float reassignments.
    if hap == 0:
        p_0 = 1.0
        p_1 = 0.0
    else:
        p_0 = 0.0
        p_1 = 1.0

    stay = 1 - 2 * rec_rate
    # Empty when index + 1 >= nLoci, which leaves both outputs at zero.
    for i in range(index + 1, nLoci):
        # Two-state switching chain: with p_0 + p_1 == 1 this equals
        # p_0' = (1 - r) * p_0 + r * p_1 (and symmetrically for p_1).
        p_0 = stay * p_0 + rec_rate
        p_1 = stay * p_1 + rec_rate
        # Guard against floating-point drift. Divide both by the same
        # pre-normalisation total; dividing p_1 by a sum that already contains
        # the normalised p_0 would use a different denominator.
        total = p_0 + p_1
        p_0 = p_0 / total
        p_1 = p_1 / total

        for trait in range(nTraits):
            score_0 = path_scores[i, trait, 0]
            score_1 = path_scores[i, trait, 1]
            point_mean = p_0 * score_0 + p_1 * score_1
            point_sq_mean = p_0 * score_0 ** 2 + p_1 * score_1 ** 2

            residual_mean[trait] += point_mean
            residual_variance[trait] += point_sq_mean - point_mean ** 2



@jit(nopython=True, nogil=True)
def sample_path(target_phenotype, path_scores, residual_means, residual_variances, rec_rate):
    """Sample a haplotype path, one locus at a time, guided by a target phenotype.

    At each locus both haplotype choices are scored with
    ``evaluate_phenotypes``, using the phenotype accumulated so far plus the
    precomputed residual mean/variance of the remaining loci. One haplotype is
    then drawn with ``numba_get_choice``, which also applies the switching
    probability relative to the previous locus' choice.

    Args:
        target_phenotype: 1-D array of per-trait target values.
        path_scores: 3-D array indexed ``[locus, trait, haplotype]``.
        residual_means: array indexed ``[locus, trait, haplotype]``, e.g. from
            ``set_mean_variance_all``.
        residual_variances: same layout as ``residual_means``.
        rec_rate: per-locus haplotype switching probability.

    Returns:
        ``(current_path, current_phenotype)``. ``current_path`` is an int8 array
        of chosen haplotypes (0/1) per locus. ``current_phenotype`` is a float32
        array holding the sum of the chosen path scores.
    """
    nLoci = path_scores.shape[0]
    log_rec_rate = np.log(rec_rate)
    log_inv_rec_rate = np.log(1 - rec_rate)

    current_phenotype = np.full(target_phenotype.shape, 0, dtype=np.float32)
    current_path = np.full(nLoci, 0, dtype=np.int8)

    for index in range(nLoci):
        log_p_0 = evaluate_phenotypes(target_phenotype, current_phenotype, path_scores,
                                      residual_means, residual_variances, 0, index)
        log_p_1 = evaluate_phenotypes(target_phenotype, current_phenotype, path_scores,
                                      residual_means, residual_variances, 1, index)

        # Every locus after the first is conditioned on its predecessor's choice.
        # The test must be `index > 0`; otherwise locus 1 would ignore locus 0.
        # -1 tells numba_get_choice that there is no previous locus.
        previous = current_path[index - 1] if index > 0 else -1
        choice = numba_get_choice(log_p_0, log_p_1, previous, log_rec_rate, log_inv_rec_rate)

        numba_add(current_phenotype, path_scores[index, :, choice])
        current_path[index] = choice

    return current_path, current_phenotype



@jit(nopython=True, nogil=True)
def numba_get_choice(log_p_0, log_p_1, current, log_rec_rate, log_inv_rec_rate):
    """Randomly draw haplotype 0 or 1 from two unnormalised log-probabilities.

    If ``current`` is 0 or 1 (the previous locus' haplotype), staying on it adds
    ``log_inv_rec_rate`` and switching adds ``log_rec_rate``. Any other value
    (callers pass -1 for "no previous locus") leaves the inputs unchanged.

    Args:
        log_p_0, log_p_1: unnormalised log-probabilities of haplotypes 0 and 1.
        current: previous haplotype, or -1 when there is none.
        log_rec_rate: log of the switching probability.
        log_inv_rec_rate: log of the staying probability.

    Returns:
        0 or 1, drawn with one call to ``np.random.random()``.
    """
    if current == 0:
        log_p_0 += log_inv_rec_rate
        log_p_1 += log_rec_rate
    elif current == 1:
        log_p_0 += log_rec_rate
        log_p_1 += log_inv_rec_rate

    # Subtract the max before exponentiating so the larger term becomes
    # exp(0) == 1, which avoids overflow/underflow for large-magnitude logs.
    log_max = max(log_p_0, log_p_1)
    p_0 = np.exp(log_p_0 - log_max)
    p_1 = np.exp(log_p_1 - log_max)
    # Only the probability of haplotype 0 is needed for the draw.
    prob_0 = p_0 / (p_0 + p_1)

    if np.random.random() < prob_0:
        return 0
    return 1

@jit(nopython=True, nogil=True)
def evaluate_phenotypes(target_phenotype, current_phentoype, path_scores, residual_means, residual_variances, hap, index):
    """Score how well choosing haplotype ``hap`` at locus ``index`` fits the target.

    For each trait, the residual is the target minus the phenotype accumulated
    so far, minus this locus' path score, minus the expected contribution of
    the remaining loci. The score is the sum over traits of
    ``-residual**2 / (2 * (1 + residual_variance))``.

    NOTE(review): this is a Gaussian exponent term only. The normalising
    ``-0.5 * log(1 + residual_variance)`` term is not included, even though
    the variances differ between hap 0 and hap 1. Confirm that is intended.

    Returns:
        The summed (unnormalised log-likelihood-like) score as a float.
    """
    total = 0.0
    for t in range(len(target_phenotype)):
        gap = (target_phenotype[t] - current_phentoype[t]
               - path_scores[index, t, hap] - residual_means[index, t, hap])
        total -= gap ** 2 / (2 * (1 + residual_variances[index, t, hap]))
    return total

@jit(nopython=True, nogil=True)
def numba_add(phenotype, update):
    """Add ``update`` element-wise into ``phenotype`` in place.

    Only the first ``len(phenotype)`` entries of ``update`` are read. Nothing is
    returned; ``phenotype`` is mutated.
    """
    n = len(phenotype)
    for k in range(n):
        phenotype[k] = phenotype[k] + update[k]


# def gen_random_path(nLoci):
#     # Just sample two paths at random.
#     # rec_rate = 1/nLoci
#     assignments = np.full(nLoci, 0, dtype = np.int8)
#     currentHap = np.random.random() < .5
#     switches = np.random.random(nLoci) < rec_rate
#     for i in range(nLoci):
#         if switches[i]:
#             currentHap = 1 - currentHap
#         assignments[i] = currentHap
#     return assignments



# # See numba_distance_speed.py in pythontests for comparison of distance functions.
# # This uses the "binomial trick" from https://blog.smola.org/post/969195661/in-praise-of-the-second-binomial-formula
# @numba.njit
# def mat_distance(reference, ref_sq, target):

#     # ref_sq = np.sum(reference**2, axis = 1)
#     target_sq = 0
#     for i in range(len(target)):
#         target_sq += target[i] ** 2

#     output = np.dot(reference, target)

#     for i in range(reference.shape[0]):
#         output[i] = -(target_sq + ref_sq[i] - 2*output[i])/2

#     return output


# @numba.njit
# def norm_sum_scores(mat): 

#     max_val = 1 # Log of anything between 0-1 will be less than 0. Using 1 as a default.
#     for a in range(len(mat)):
#         if mat[a] > max_val or max_val == 1:
#             max_val = mat[a]
#     for a in range(len(mat)):
#         mat[a] -= max_val

#     # Should flag for better numba-ness.
#     tmp = np.exp(mat)
#     total = 0
#     for a in range(len(mat)):
#         total += tmp[a]

#     log_total = np.log(total) + max_val # Not correcting for number of paths -- if more paths are one way or the other, we're actually okay with that.

#     return log_total




# nTraits = 50
# nLoci = 10

# # path_0 = [[10, 0, 0, 0, 0, 0,  0], 
# #          [20, 0, 0, 0, 0, 0,  0],
# #          [0, 30, 0, 0, 0, 0,  0],
# #          [0,  0, 0, 0, 0, 0,  0],
# #          [0,  0, 0, 0, 0, 0,  0]]
# # path_1 = [[0,  0, 0, 0, 0, 0,  0], 
# #          [0,  0, 0, 0, 0, 0,  0],
# #          [0,  0, 0, 0, 0, 0,  0],
# #          [0,  0, 0, 0, 0, 0, 40],
# #          [0,  0, 0, 0, 0, 0, 50]]

# # path_0 = np.full((nTraits, nLoci), 0, dtype = np.float32)
# np.random.seed(5)
# path_0 = np.random.normal(size = (nTraits, nLoci)).astype(np.float32)/np.sqrt(nLoci)
# path_1 = np.random.normal(size = (nTraits, nLoci)).astype(np.float32)/np.sqrt(nLoci)

# path_scores = np.array([path_0, path_1], dtype = np.float32)
# path_scores = np.transpose(path_scores, (2, 1, 0))

# def gen_random_path(nLoci):
#     # Just sample two paths at random.
#     # rec_rate = 1/nLoci
#     assignments = np.full(nLoci, 0, dtype = np.int8)
#     currentHap = np.random.random() < .5
#     switches = np.random.random(nLoci) < rec_rate
#     for i in range(nLoci):
#         if switches[i]:
#             currentHap = 1 - currentHap
#         assignments[i] = currentHap
#     return assignments

# rec_rate = 1/nLoci

# true_path = gen_random_path(path_scores.shape[0])
# target_phenotype = np.sum((1-true_path[:,None])*path_scores[:,:,0] + true_path[:,None]*path_scores[:,:,1], axis = 0)


# residual_means, residual_variances = set_mean_variance_all(path_scores, rec_rate)

# current_path, current_phentoype = sample_path(target_phenotype, path_scores, residual_means, residual_variances, rec_rate)


# print(true_path)
# print(current_path)

# print(residual_means)

